from typing import List, Optional, Union
import re
import torch
import numpy as np

re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)
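# The alternatives above match, in order: the escaped sequences \( \) \[ \]
# and \\, a lone backslash, an opening "(" or "[", an explicit ":<weight>)"
# suffix (the weight is captured as group 1), a closing ")" or "]", a run of
# plain text, and a bare ":".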


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs:
        text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')  # doctest: +NORMALIZE_WHITESPACE
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        word = m.group(0)
        weight = m.group(1)

        if word.startswith("\\"):
            # escaped bracket or backslash: emit the literal character
            res.append([word[1:], 1.0])
        elif word == "(":
            round_brackets.append(len(res))
        elif word == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            # an explicit ":weight)" closes the innermost round bracket
            multiply_range(round_brackets.pop(), float(weight))
        elif word == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif word == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([word, 1.0])

    # brackets that were never closed still apply their multiplier to
    # everything emitted after they were opened
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return their tokens together with each
    token's weight. No padding, starting, or ending token is included.
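
    Example (token ids here are hypothetical): for prompt ["a (cat:1.5)"],
    the result might be tokens [[320, 2368]] and weights [[1.0, 1.5]].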
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = pipe.tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        print(
            "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
        )
    return tokens, weights


def pad_tokens_and_weights(
    tokens,
    weights,
    max_length,
    bos,
    eos,
    no_boseos_middle=True,
    chunk_length=77,
):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
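
    For example, with chunk_length=77 and max_length=152,
    max_embeddings_multiples = (152 - 2) // (77 - 2) = 2; when
    no_boseos_middle is False, each of the two 77-token chunks gets its own
    bos/eos weight of 1.0.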
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = (
        max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    )
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][
                        j
                        * (chunk_length - 2) : min(
                            len(weights[i]), (j + 1) * (chunk_length - 2)
                        )
                    ]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    pipe,
    text_input,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[
                :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
            ].clone()

            # overwrite the head and the tail with the starting and ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            text_input_chunk[:, -1] = text_input[0, -1]

            text_embedding = pipe.run("clip", text_input_chunk)[0].to_host()

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        # AMDSHARK: concatenate the chunk embeddings along the sequence
        # dimension (equivalent to torch.concat(text_embeddings, axis=1))
        # and convert the result to a torch tensor
        text_embeddings_np = np.concatenate(text_embeddings, axis=1)
        text_embeddings = torch.from_numpy(text_embeddings_np)
    else:
        text_embeddings = pipe.run("clip", text_input)[0]
        text_embeddings = torch.from_numpy(text_embeddings.to_host())
    return text_embeddings


# This function deals with NoneType values occurring in tokens after padding.
# It replaces None with 49407 (the CLIP tokenizer's eos token id), since
# truncating None values causes matrix dimension errors. Note that it only
# filters the first sequence in the batch.
def filter_nonetype_tokens(tokens: List[List]):
    return [[49407 if token is None else token for token in tokens[0]]]
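# e.g. filter_nonetype_tokens([[49406, 320, None, None]])
#   -> [[49406, 320, 49407, 49407]]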


def get_weighted_text_embeddings(
    pipe,
    prompt: List[str],
    uncond_prompt: Optional[Union[List[str], str]] = None,
    max_embeddings_multiples: Optional[int] = 8,
    no_boseos_middle: Optional[bool] = True,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
):
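    r"""
    Prompts can be assigned local weights using brackets, e.g.
    "a (cat:1.5) on a bench" weights the tokens of "cat" by 1.5 (see
    parse_prompt_attention for the full syntax).

    Returns a pair (text_embeddings, uncond_embeddings); the second element
    is None when uncond_prompt is None.
    """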
    max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(
            pipe, prompt, max_length - 2
        )
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens, uncond_weights = get_prompts_with_weights(
                pipe, uncond_prompt, max_length - 2
            )
    else:
        prompt_tokens = [
            token[1:-1]
            for token in pipe.tokenizer(
                prompt, max_length=max_length, truncation=True
            ).input_ids
        ]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1]
                for token in pipe.tokenizer(
                    uncond_prompt, max_length=max_length, truncation=True
                ).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(max_length, max([len(token) for token in uncond_tokens]))
    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)

    max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
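    # e.g. with model_max_length=77 and a longest prompt of 100 tokens,
    # max_embeddings_multiples becomes min(8, (100 - 1) // 75 + 1) = 2 and
    # max_length is rounded up to 75 * 2 + 2 = 152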

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.model_max_length,
    )

    # FIXME: hacky workaround for the tokenizer padding with None values
    prompt_tokens = filter_nonetype_tokens(prompt_tokens)

    # prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.model_max_length,
        )

        # FIXME: hacky workaround for the tokenizer padding with None values
        uncond_tokens = filter_nonetype_tokens(uncond_tokens)

        # uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu")

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.model_max_length,
        no_boseos_middle=no_boseos_middle,
    )
    # prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
    prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu")
    if uncond_prompt is not None:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.model_max_length,
            no_boseos_middle=no_boseos_middle,
        )
        # uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
        uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu")

    # scale the embeddings by the per-token weights, then renormalize so the
    # mean magnitude matches that of the unweighted embeddings
    # TODO: should we normalize per chunk or over the whole sequence (current implementation)?
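    # i.e. e_weighted = e * w per token, then
    # e_weighted *= mean(e) / mean(e_weighted), with the mean taken over all
    # tokens and channels of each batch item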
    if (not skip_parsing) and (not skip_weighting):
        previous_mean = (
            text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        )
        text_embeddings *= prompt_weights.unsqueeze(-1)
        current_mean = (
            text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        )
        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
        if uncond_prompt is not None:
            previous_mean = (
                uncond_embeddings.float()
                .mean(axis=[-2, -1])
                .to(uncond_embeddings.dtype)
            )
            uncond_embeddings *= uncond_weights.unsqueeze(-1)
            current_mean = (
                uncond_embeddings.float()
                .mean(axis=[-2, -1])
                .to(uncond_embeddings.dtype)
            )
            uncond_embeddings *= (
                (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
            )

    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings
    return text_embeddings, None
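

# A minimal usage sketch. `sd_pipe` is a hypothetical pipeline object; as
# assumed throughout this module, it must expose a HuggingFace-style
# `tokenizer`, an int `model_max_length`, and a `run("clip", input_ids)`
# method whose results support `.to_host()`:
#
#     cond_embeds, uncond_embeds = get_weighted_text_embeddings(
#         sd_pipe,
#         prompt=["a (photo:1.2) of a ((cat)) on a bench"],
#         uncond_prompt=["blurry, [low] quality"],
#         max_embeddings_multiples=4,
#     )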
